package com.learn.security.oauth2.auth.server.config.oauth2;

import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

import javax.servlet.http.HttpServletRequest;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * HttpServletRequest -> OAuth2UsernamePasswordAuthenticationToken
 *
 * @author knight
 */
public class OAuth2PasswordAuthenticationConverter implements AuthenticationConverter {

	/**
	 * org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2EndpointUtils
	 * 得到参数
	 * @param request 请求
	 * @return {@link MultiValueMap<String, String>}
	 */
	static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
		Map<String, String[]> parameterMap = request.getParameterMap();
		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
		parameterMap.forEach((key, values) -> {
			if (values.length > 0) {
				for (String value : values) {
					parameters.add(key, value);
				}
			}
		});
		return parameters;
	}

	@Override
	public Authentication convert(HttpServletRequest request) {
		String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
		// 如果不是密码授权方式
		if (!StrUtil.equals(grantType, AuthorizationGrantType.PASSWORD.getValue())) {
			return null;
		}
		MultiValueMap<String, String> parameters = getParameters(request);
		String username = parameters.getFirst(OAuth2ParameterNames.USERNAME);
		// 如果用户名不存在或参数数量>1
		if (StrUtil.isBlank(username) || parameters.get(OAuth2ParameterNames.USERNAME).size() != 1) {
			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST,
					String.format("缺少: %s参数", OAuth2ParameterNames.USERNAME), StringPool.EMPTY));
		}
		// 如果密码不存在或参数数量>1
		String password = parameters.getFirst(OAuth2ParameterNames.PASSWORD);
		if (StrUtil.isBlank(password) || parameters.get(OAuth2ParameterNames.PASSWORD).size() != 1) {
			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST,
					String.format("缺少: %s参数", OAuth2ParameterNames.PASSWORD), StringPool.EMPTY));
		}
		Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
		if (clientPrincipal == null) {
			throw new OAuth2AuthenticationException(
					new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, "缺少客户端参数", StringPool.EMPTY));
		}
		// 获取额外参数
		Map<String, Object> additionalParameters = parameters.entrySet().stream()
				.filter(entry -> !StrUtil.equals(entry.getKey(), OAuth2ParameterNames.USERNAME)
						&& !StrUtil.equals(entry.getKey(), OAuth2ParameterNames.PASSWORD))
				.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().get(0)));

		return new OAuth2PasswordAuthenticationToken(username, password, clientPrincipal, additionalParameters);
	}

}
